package db

import (
	"context"
	"fmt"
	"sync"
	"time"

	logger "gitee.com/allan577/go-lib-logger"
	"github.com/go-sql-driver/mysql"
	"github.com/jinzhu/gorm"
	_ "github.com/jinzhu/gorm/dialects/postgres"
	"github.com/opentracing/opentracing-go"
)

// ----------------------------------------
//  GORM client
// ----------------------------------------

// GORMClient wraps a *gorm.DB handle together with the configuration used to
// open it and some bookkeeping timestamps.
type GORMClient struct {
	// name is the instance name given at creation; returned by Name.
	name        string
	// config is the copy of the caller's ORMConfig taken in CreateGORMClient.
	config      *ORMConfig
	// mu is held by Connect and Close while they read or replace client.
	mu          *sync.Mutex
	// client is the underlying gorm handle; nil before the first successful
	// open in Connect and again after Close.
	client      *gorm.DB
	// connectedAt is the local time at which Connect last opened a new handle.
	connectedAt time.Time
	// pingedAt is the local time at which Ping last attempted a ping.
	pingedAt    time.Time
}

// gormCallbacksOnce makes WithCtx invoke AddGormCallbacks at most once for the
// whole package.
// NOTE(review): because this is package-level, only the *gorm.DB of the first
// client that reaches WithCtx with a span is passed to AddGormCallbacks —
// confirm that the callbacks registered there also apply to other clients and
// to handles re-opened by Reconnect.
var gormCallbacksOnce = &sync.Once{}

// CreateGORMClient builds a GORM client called name from a copy of config and
// connects it right away. It returns the connected client, or a nil client
// together with the error reported by Connect.
func CreateGORMClient(name string, config *ORMConfig) (*GORMClient, error) {
	instance := &GORMClient{
		name:   name,
		config: config.Copy(),
		mu:     &sync.Mutex{},
	}
	if connErr := instance.Connect(); connErr != nil {
		return nil, connErr
	}
	return instance, nil
}

// Name returns the name this GORM client instance was created with.
func (client *GORMClient) Name() string {
	return client.name
}

// Client returns the underlying third-party *gorm.DB handle held by this
// client. The value is nil after Close has run and until Connect opens a new
// handle. The field is read without taking the client's mutex.
func (client *GORMClient) Client() *gorm.DB {
	return client.client
}

// SetLogger installs log as the logger of the underlying *gorm.DB handle.
// NOTE(review): the handle is nil after Close and no nil check is made here —
// calling SetLogger in that state presumably panics; confirm against gorm.
func (client *GORMClient) SetLogger(log logger.Log) {
	client.client.SetLogger(log)
}

// WithCtx returns the *gorm.DB to use for work done on behalf of ctx.
//
// When ctx is nil or carries no OpenTracing span, the client's handle is
// returned as is. Otherwise AddGormCallbacks is invoked (at most once for the
// package, guarded by gormCallbacksOnce) and the result of storing the span on
// the handle under parentSpanGormKey is returned.
func (client *GORMClient) WithCtx(ctx context.Context) *gorm.DB {
	db := client.client
	if ctx == nil {
		return db
	}

	span := opentracing.SpanFromContext(ctx)
	if span == nil {
		// Nothing to propagate: no span is attached to this context.
		return db
	}

	gormCallbacksOnce.Do(func() { AddGormCallbacks(db) })
	return db.Set(parentSpanGormKey, span)
}

// Ping checks that the remote service is reachable.
//
// It does nothing and returns nil when pinging is disabled in the config or
// when no handle is currently open. Otherwise it records the local time of the
// attempt in pingedAt and returns the result of pinging the handle's database.
// Ping does not take the client's mutex itself; Connect calls it while holding
// the lock.
func (client *GORMClient) Ping() error {
	if !client.config.Ping || client.client == nil {
		return nil
	}
	client.pingedAt = time.Now().Local()
	return client.client.DB().Ping()
}

// Connect opens the connection to the remote database if none is open yet.
//
// If a handle already exists, Connect only pings it (see Ping). Otherwise it
// builds a DSN from the client's config — a keyword/value connection string
// when Driver is "postgres", a go-sql-driver/mysql DSN for every other Driver
// value — opens the handle with gorm.Open, applies the MaxConn and MaxIdleConn
// pool limits, records connectedAt and finally pings the new handle.
//
// The whole operation runs under the client's mutex. It returns the error from
// gorm.Open or from Ping, or nil on success. Note that when the final Ping
// fails the freshly opened handle is still kept on the client.
func (client *GORMClient) Connect() error {
	client.mu.Lock()
	defer client.mu.Unlock()
	if client.client != nil {
		return client.Ping()
	}

	var dsn string
	if client.config.Driver == "postgres" {
		// Every configured value is single-quoted and escaped. Written bare, an
		// empty value or one containing a space, quote or backslash (typical
		// for passwords) would corrupt the connection string or let the value
		// smuggle in extra parameters.
		dsn = fmt.Sprintf(
			"host=%s user=%s dbname=%s sslmode=disable password=%s",
			quotePostgresDSNValue(client.config.Host),
			quotePostgresDSNValue(client.config.Username),
			quotePostgresDSNValue(client.config.Database),
			quotePostgresDSNValue(client.config.Password))
	} else {
		c := mysql.NewConfig() // used only to build the DSN
		c.Loc = time.Local
		c.User = client.config.Username
		c.Passwd = client.config.Password
		c.Net = "tcp"
		c.Addr = fmt.Sprintf("%s:%d", client.config.Host, client.config.Port)
		c.DBName = client.config.Database
		c.Collation = "utf8mb4_general_ci"
		c.ParseTime = true
		// Use the MaxAllowedPacket configured on the MySQL server; if it is not
		// set here the default is 4M.
		c.MaxAllowedPacket = 0
		dsn = c.FormatDSN()
	}

	db, err := gorm.Open(client.config.Driver, dsn)
	if err != nil {
		return err
	}
	db.DB().SetMaxOpenConns(client.config.MaxConn)
	db.DB().SetMaxIdleConns(client.config.MaxIdleConn)
	// db.DB().SetConnMaxLifetime(time.Second * 300)
	client.client = db
	client.connectedAt = time.Now().Local()
	return client.Ping()
}

// quotePostgresDSNValue wraps value in single quotes and backslash-escapes any
// single quote or backslash inside it, as required for values in a PostgreSQL
// keyword/value connection string.
func quotePostgresDSNValue(value string) string {
	quoted := make([]byte, 0, len(value)+2)
	quoted = append(quoted, '\'')
	// Byte-wise scanning is safe for UTF-8: both characters being escaped are
	// ASCII and never occur inside a multi-byte sequence.
	for i := 0; i < len(value); i++ {
		if c := value[i]; c == '\'' || c == '\\' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, value[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}

// Close shuts down the connection to the remote service.
//
// Under the client's mutex it closes the current handle, if there is one, and
// then drops the reference regardless of the outcome. It returns the error
// from closing the handle, or nil when no handle was open.
func (client *GORMClient) Close() error {
	client.mu.Lock()
	defer client.mu.Unlock()

	if client.client == nil {
		return nil
	}
	closeErr := client.client.Close()
	// Always discard the handle, even if closing it reported an error.
	client.client = nil
	return closeErr
}

// Reconnect closes the current connection, if any, and then connects again.
// It returns the result of Connect. Close and Connect each take the client's
// mutex separately, so the lock is not held across the two calls.
func (client *GORMClient) Reconnect() error {
	// The Close error is deliberately discarded: the handle is dropped either
	// way, and a new connection is attempted regardless.
	_ = client.Close()
	return client.Connect()
}
